{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the `lightly` package, which can be installed with:\n",
    "`pip install lightly`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the package is installed into the environment of the\n",
    "# running kernel rather than whichever pip happens to be first on the shell PATH.\n",
    "%pip install lightly"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Note: The model and training settings do not follow the reference settings\n",
    "from the DINO paper (Caron et al., 2021, *Emerging Properties in Self-Supervised\n",
    "Vision Transformers*). The settings are chosen such that the example can easily\n",
    "be run on a small dataset with a single GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly.loss import DINOLoss\n",
    "from lightly.models.modules import DINOProjectionHead\n",
    "from lightly.models.utils import deactivate_requires_grad, update_momentum\n",
    "from lightly.transforms.dino_transform import DINOTransform\n",
    "from lightly.utils.scheduler import cosine_schedule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DINO(torch.nn.Module):\n",
    "    \"\"\"Student/teacher pair of backbones and DINO projection heads.\n",
    "\n",
    "    The student is the trainable network. The teacher backbone is a deep copy\n",
    "    of the backbone passed in, the teacher head is a separately constructed\n",
    "    projection head, and gradients are deactivated for both teacher modules.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, backbone, input_dim):\n",
    "        \"\"\"Build the student and teacher networks.\n",
    "\n",
    "        Args:\n",
    "            backbone: Feature extractor; it becomes the student backbone and\n",
    "                is deep-copied to create the teacher backbone.\n",
    "            input_dim: First argument forwarded to both projection heads\n",
    "                (expected to be the feature dimension of the backbone).\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # Remaining positional arguments shared by both projection heads.\n",
    "        head_dims = (512, 64, 2048)\n",
    "\n",
    "        self.student_backbone = backbone\n",
    "        self.student_head = DINOProjectionHead(\n",
    "            input_dim, *head_dims, freeze_last_layer=1\n",
    "        )\n",
    "\n",
    "        self.teacher_backbone = copy.deepcopy(backbone)\n",
    "        self.teacher_head = DINOProjectionHead(input_dim, *head_dims)\n",
    "\n",
    "        # The teacher modules do not receive gradients.\n",
    "        for teacher_module in (self.teacher_backbone, self.teacher_head):\n",
    "            deactivate_requires_grad(teacher_module)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the student head output for the flattened student features of ``x``.\"\"\"\n",
    "        features = self.student_backbone(x)\n",
    "        return self.student_head(features.flatten(start_dim=1))\n",
    "\n",
    "    def forward_teacher(self, x):\n",
    "        \"\"\"Return the teacher head output for the flattened teacher features of ``x``.\"\"\"\n",
    "        features = self.teacher_backbone(x)\n",
    "        return self.teacher_head(features.flatten(start_dim=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ResNet-18 backbone; no pretrained weights are requested here.\n",
    "resnet = torchvision.models.resnet18()\n",
    "# Keep every child module except the last one (the classification layer in\n",
    "# torchvision's ResNet), so the backbone outputs features instead of logits.\n",
    "backbone = nn.Sequential(*list(resnet.children())[:-1])\n",
    "# Feature dimension handed to the projection heads; must match the backbone output.\n",
    "input_dim = 512\n",
    "# instead of a resnet you can also use a vision transformer backbone as in the\n",
    "# original paper (you might have to reduce the batch size in this case):\n",
    "# backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)\n",
    "# input_dim = backbone.embed_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the backbone in the student/teacher DINO module.\n",
    "model = DINO(backbone, input_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on the GPU when one is available, otherwise fall back to the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# Moves both student and teacher to the device. As the last expression of the\n",
    "# cell, the returned module is also shown in the cell output.\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augmentation pipeline applied to every image. The training loop receives a\n",
    "# list of views per batch and treats the first two as the global views.\n",
    "transform = DINOTransform()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_transform(t):\n",
    "    \"\"\"Discard the annotation ``t`` and return the constant dummy label 0.\n",
    "\n",
    "    The object detection annotations are not needed for this example, so\n",
    "    every target is replaced by the same placeholder value.\n",
    "    \"\"\"\n",
    "    dummy_label = 0\n",
    "    return dummy_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pascal VOC detection images stored under datasets/pascal_voc (downloaded if\n",
    "# requested via download=True); the annotations are replaced by target_transform.\n",
    "dataset = torchvision.datasets.VOCDetection(\n",
    "    \"datasets/pascal_voc\",\n",
    "    download=True,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    ")\n",
    "# or create a dataset from a folder containing images or videos\n",
    "# (requires `from lightly.data import LightlyDataset`):\n",
    "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=64,  # lower this if you run out of GPU memory\n",
    "    shuffle=True,\n",
    "    drop_last=True,  # skip the final incomplete batch so every batch has 64 samples\n",
    "    num_workers=8,  # worker processes used for loading and augmenting the data\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = DINOLoss(\n",
    "    # presumably has to match the output dimension of the projection heads\n",
    "    # (the 2048 passed to DINOProjectionHead in the DINO class)\n",
    "    output_dim=2048,\n",
    "    warmup_teacher_temp_epochs=5,\n",
    ")\n",
    "# move loss to correct device because it also contains parameters\n",
    "criterion = criterion.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All model parameters are handed to Adam. The teacher modules have their\n",
    "# gradients deactivated in DINO.__init__ and are updated via update_momentum\n",
    "# in the training loop instead.\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of training epochs; also passed to cosine_schedule as the schedule length.\n",
    "epochs = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Starting Training\")\n",
    "for epoch in range(epochs):\n",
    "    total_loss = 0\n",
    "    # Momentum used for the teacher update during this epoch (cosine schedule\n",
    "    # over all epochs with the values 0.996 and 1).\n",
    "    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)\n",
    "    for batch in dataloader:\n",
    "        # batch[0] holds the augmented views; the dummy targets are not used.\n",
    "        views = batch[0]\n",
    "        # Update the teacher from the student before computing the outputs.\n",
    "        update_momentum(model.student_backbone, model.teacher_backbone, m=momentum_val)\n",
    "        update_momentum(model.student_head, model.teacher_head, m=momentum_val)\n",
    "        views = [view.to(device) for view in views]\n",
    "        # The teacher only sees the first two views; the student sees all of them.\n",
    "        global_views = views[:2]\n",
    "        teacher_out = [model.forward_teacher(view) for view in global_views]\n",
    "        student_out = [model.forward(view) for view in views]\n",
    "        loss = criterion(teacher_out, student_out, epoch=epoch)\n",
    "        # Detach so the running sum does not keep the autograd graph alive.\n",
    "        total_loss += loss.detach()\n",
    "        loss.backward()\n",
    "        # We only cancel gradients of student head.\n",
    "        model.student_head.cancel_last_layer_gradients(current_epoch=epoch)\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    # Mean loss over all batches of the epoch.\n",
    "    avg_loss = total_loss / len(dataloader)\n",
    "    print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
